{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Identifying Relevant Pages in a PDF using Gemini 2.0\n",
    "\n",
    "The goal of this notebook is to extract specific information from a large PDF by using Gemini to identify relevant pages and create a new, focused PDF.\n",
    "\n",
    "In this notebook, you will:\n",
    " - Use Gemini to identify pages in a large PDF that contain information about a given question.\n",
    " - Extract and compile the identified pages into a new PDF.\n",
    " - Save the PDF to a file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install python packages\n",
    "! pip install -U pypdf\n",
    "! pip install -U google-cloud-aiplatform\n",
    "! pip install -U pdf2image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import all the required python packages\n",
    "import io\n",
    "import json\n",
    "import pypdf\n",
    "import vertexai\n",
    "\n",
    "from pdf2image import convert_from_bytes\n",
    "from IPython.display import display\n",
    "from typing import Iterable\n",
    "\n",
    "from vertexai.preview.generative_models import (\n",
    "    GenerationResponse,\n",
    "    GenerativeModel,\n",
    "    HarmBlockThreshold,\n",
    "    HarmCategory,\n",
    "    Part\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Include information about your project in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT_ID = \"[your-project-id]\"  # Replace with your project ID\n",
    "LOCATION = \"us-central1\"  # Replace with your location\n",
    "MODEL_NAME = \"gemini-2.0-flash-001\"\n",
    "\n",
    "vertexai.init(project=PROJECT_ID, location=LOCATION)\n",
    "model = GenerativeModel(MODEL_NAME)\n",
    "BLOCK_LEVEL = HarmBlockThreshold.BLOCK_ONLY_HIGH"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is the prompt used to extract the pages related to the question."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT_PAGES = \"\"\"\n",
    "Return the numbers of all pages in the document above that contain information related to the question below.\n",
    "<Instructions>\n",
    " - Use the document above as your only source of information to determine which pages are related to the question below.\n",
    " - Return the page numbers of the document above that are related to the question. When in doubt, return the page anyway.\n",
    " - The response should be a JSON list, as shown in the example below.\n",
    "</Instructions>\n",
    "<Suggestions>\n",
    " - The document above is a financial report with various tables, charts, infographics, lists, and additional text information.\n",
    " - Pay CLOSE ATTENTION to the chart legends and chart COLORS to determine the pages. Colors may indicate which information is important for determining the pages.\n",
    " - The color of the chart legends represents the color of the bars in the chart.\n",
    " - Use ONLY this document as context to determine the pages.\n",
    " - In most cases, the page number can be found in the footer.\n",
    "</Suggestions>\n",
    "<Question>\n",
    "{question}\n",
    "</Question>\n",
    "<Example JSON Output>\n",
    "{{\n",
    "  \"pages\": [1, 2, 3, 4, 5]\n",
    "}}\n",
    "</Example JSON Output>\n",
    "json:\"\"\""
   ]
  },
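  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The prompt above is later filled in with `str.format`, so `{question}` acts as a placeholder while the doubled braces `{{` and `}}` come out as literal braces in the JSON example. A minimal sketch of that behavior, using a shortened, hypothetical template (`TEMPLATE` is an illustration and not part of this notebook):\n",
    "\n",
    "```python\n",
    "# Hypothetical, shortened template; the notebook's real prompt is PROMPT_PAGES.\n",
    "TEMPLATE = '<Question>{question}</Question> {{\"pages\": [1, 2]}}'\n",
    "\n",
    "# str.format fills {question} and turns {{ and }} into single braces.\n",
    "filled = TEMPLATE.format(question='What were total assets in 2023?')\n",
    "print(filled)\n",
    "```\n",
    "\n",
    "Without the doubled braces, `str.format` would try to treat the JSON example as a placeholder and fail."
   ]
  },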
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pdf_cut(pdf_bytes: bytes, pages: list[int]) -> bytes:\n",
    "    \"\"\"Using the pdf bytes and a list of page numbers,\n",
    "    return the pdf bytes of a new pdf with only those pages\n",
    "    Args:\n",
    "        pdf_bytes:\n",
    "            Bytes of a pdf file\n",
    "        pages:\n",
    "            List of page numbers to extract from the pdf bytes\n",
    "    Returns:\n",
    "        Bytes of a new pdf with only the extracted pages\n",
    "    \"\"\"\n",
    "    pdf_reader = pypdf.PdfReader(io.BytesIO(pdf_bytes))\n",
    "    pdf_writer = pypdf.PdfWriter()\n",
    "    for page in pages:\n",
    "        try:\n",
    "            pdf_writer.add_page(pdf_reader.pages[page - 1])\n",
    "        except Exception as e:\n",
    "            pass\n",
    "    output = io.BytesIO()\n",
    "    pdf_writer.write(output)\n",
    "    return output.getvalue()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(\n",
    "    prompt: list,\n",
    "    max_output_tokens: int = 2048,\n",
    "    temperature: int = 2,\n",
    "    top_p: float = 0.4,\n",
    "    stream: bool = False,\n",
    ") -> GenerationResponse | Iterable[GenerationResponse]:\n",
    "    \"\"\"\n",
    "    Function to generate response using Gemini 2.0\n",
    "\n",
    "    Args:\n",
    "        prompt:\n",
    "            List of prompt parts\n",
    "        max_output_tokens:\n",
    "            Max Output tokens\n",
    "        temperature:\n",
    "            Temperature for the model\n",
    "        top_p:\n",
    "            Top-p for the model\n",
    "        stream:\n",
    "            Strem results?\n",
    "\n",
    "    Returns:\n",
    "        Model response\n",
    "\n",
    "    \"\"\"\n",
    "    responses = model.generate_content(\n",
    "        prompt,\n",
    "        generation_config={\n",
    "            \"max_output_tokens\": max_output_tokens,\n",
    "            \"temperature\": temperature,\n",
    "            \"top_p\": top_p,\n",
    "        },\n",
    "        safety_settings={\n",
    "            HarmCategory.HARM_CATEGORY_HATE_SPEECH: BLOCK_LEVEL,\n",
    "            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: BLOCK_LEVEL,\n",
    "            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: BLOCK_LEVEL,\n",
    "            HarmCategory.HARM_CATEGORY_HARASSMENT: BLOCK_LEVEL,\n",
    "        },\n",
    "        stream=stream,\n",
    "    )\n",
    "\n",
    "    return responses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pdf_pages(\n",
    "    question: str, \n",
    "    pdf_bytes: bytes, \n",
    "    instructions_prompt: str = PROMPT_PAGES\n",
    ") -> list[int]:\n",
    "    \"\"\"\n",
    "    Function to generate a list of page numbers with pdf bytes and a question\n",
    "\n",
    "    Args:\n",
    "        question:\n",
    "            Question to ask the model\n",
    "        pdf_bytes:\n",
    "            PDF bytes\n",
    "        instructions_prompt:\n",
    "            Prompt for the model\n",
    "\n",
    "    Returns:\n",
    "        List of page numbers\n",
    "    \"\"\"\n",
    "    pdf_document = Part.from_data(data=pdf_bytes, mime_type=\"application/pdf\")\n",
    "    prompt = [\n",
    "        \"<Document>\",\n",
    "        pdf_document,\n",
    "        \"</Document>\",\n",
    "        instructions_prompt.format(question=question),\n",
    "    ]\n",
    "    responses = generate(prompt=prompt)\n",
    "\n",
    "    if isinstance(responses, GenerationResponse):\n",
    "        output_json = json.loads(responses.text)\n",
    "    else:\n",
    "        output_json = json.loads(\n",
    "            \" \".join([response.text for response in responses])\n",
    "        )\n",
    "    return output_json[\"pages\"]"
   ]
  },
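  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`json.loads` fails if the model wraps its answer in a Markdown code fence or adds text around the JSON, which can happen even with the `json:` cue at the end of the prompt. The sketch below assumes a hypothetical helper named `parse_pages` (not part of this notebook) that keeps only the text between the first `{` and the last `}` before parsing:\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def parse_pages(text: str) -> list[int]:\n",
    "    # Hypothetical helper: pull the JSON object out of the model's reply.\n",
    "    start, end = text.find('{'), text.rfind('}')\n",
    "    if start == -1 or end < start:\n",
    "        raise ValueError('no JSON object found in the model response')\n",
    "    # Keep only the text between the first '{' and the last '}'.\n",
    "    payload = json.loads(text[start:end + 1])\n",
    "    return [int(page) for page in payload['pages']]\n",
    "\n",
    "\n",
    "print(parse_pages('{\"pages\": [9]}'))\n",
    "```\n",
    "\n",
    "In `pdf_pages`, a helper like this could stand in for the direct `json.loads(responses.text)` call."
   ]
  },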
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the next cell, include information about your question and the pdf_path.  \n",
    "\n",
    "**(Optional)**  \n",
    "If you are using Colab to test this notebook, you can try the following code to upload your PDF files.  \n",
    "```python\n",
    "from google.colab import files\n",
    "files.upload()\n",
    "```\n",
    "\n",
    "You can uncomment the code in the cell to use this method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from google.colab import files\n",
    "# files.upload()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Include your question and the path to your PDF\n",
    "# question = \"What are the key trends for financial services industry?\"\n",
    "question = \"From the Consolidated Balance Sheet, what was the difference between the total assets from 2022 to 2023?\"\n",
    "pdf_path = \"./Cymbal Bank - Financial Statements.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9]\n"
     ]
    }
   ],
   "source": [
    "# Open the file, extract the pages using Gemini 2.0 and print them\n",
    "with open(pdf_path, \"rb\") as f:\n",
    "    pdf_bytes = f.read()\n",
    "pages = pdf_pages(question=question, pdf_bytes=pdf_bytes)\n",
    "print(pages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To ensure we find the answer to the question, it will also retrieve the page immediately after those.\n",
    "expanded_pages = set(pages)\n",
    "expanded_pages.update({i+1 for i in pages})\n",
    "new_pdf = pdf_cut(pdf_bytes=pdf_bytes, pages=list(expanded_pages))"
   ]
  },
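  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The expansion step can also be written as a small function that returns the pages in order and drops any number outside the document. A minimal sketch, where `expand_pages` and its `total_pages` argument are hypothetical names and not part of this notebook:\n",
    "\n",
    "```python\n",
    "def expand_pages(pages: list[int], total_pages: int) -> list[int]:\n",
    "    # Hypothetical helper: add the page after each hit, then keep 1..total_pages.\n",
    "    expanded = set(pages) | {page + 1 for page in pages}\n",
    "    return sorted(page for page in expanded if 1 <= page <= total_pages)\n",
    "\n",
    "\n",
    "print(expand_pages([9, 3], total_pages=9))  # [3, 4, 9]\n",
    "```\n",
    "\n",
    "With pypdf, `total_pages` could be obtained from `len(pypdf.PdfReader(io.BytesIO(pdf_bytes)).pages)`."
   ]
  },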
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the result to a new PDF document\n",
    "with open(\"./sample.pdf\", \"wb\") as fp:\n",
    "    fp.write(new_pdf)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### (Optional) Print the PDF pages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "images = convert_from_bytes(new_pdf)\n",
    "for i, image in enumerate(images):\n",
    "    display(image)"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "workbench-notebooks.m128",
   "type": "gcloud",
   "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/workbench-notebooks:m128"
  },
  "kernelspec": {
   "display_name": "py311",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
